{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Object Following - Live Demo\n",
    "\n",
    "In this notebook we'll show how you can follow an object with JetBot!  We'll use a pre-trained neural network\n",
    "that was trained on the [COCO dataset](http://cocodataset.org) to detect 90 different common objects.  These include\n",
    "\n",
    "* Person (index 0)\n",
    "* Cup (index 47)\n",
    "\n",
    "and many others (you can check [this file](https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_complete_label_map.pbtxt) for a full list of class indices).  The model is sourced from the [TensorFlow object detection API](https://github.com/tensorflow/models/tree/master/research/object_detection),\n",
    "which provides utilities for training object detectors for custom tasks also!  Once the model is trained, we optimize it using NVIDIA TensorRT on the Jetson Nano.\n",
    "\n",
    "This makes the network very fast, capable of real-time execution on Jetson Nano!  We won't run through all of the training and optimization steps in this notebook though.\n",
    "\n",
    "Anyways, let's get started.  First, we'll want to import the ``ObjectDetector`` class which takes our pre-trained SSD engine."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute detections on single camera image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "from jetbot import ObjectDetector\n",
    "\n",
    "model = ObjectDetector('ssd_mobilenet_v2_coco.engine')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Internally, the ``ObjectDetector`` class uses the TensorRT Python API to execute the engine that we provide.  It also takes care of preprocessing the input to the neural network, as\n",
    "well as parsing the detected objects.  Right now it will only work for engines created using the ``jetbot.ssd_tensorrt`` package. That package has the utilities for converting\n",
    "the model from the TensorFlow object detection API to an optimized TensorRT engine.\n",
    "\n",
    "Next, let's initialize our camera.  Our detector takes 300x300 pixel input, so we'll set this when creating the camera.\n",
    "\n",
    "> Internally, the Camera class uses GStreamer to take advantage of Jetson Nano's Image Signal Processor (ISP).  This is super fast and offloads\n",
    "> a lot of the resizing computation from the CPU. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetbot import Camera\n",
    "\n",
    "camera = Camera.instance(width=300, height=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's execute our network using some camera input.  By default the ``ObjectDetector`` class expects ``bgr8`` format that the camera produces.  However,\n",
    "you could override the default pre-processing function if your input is in a different format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "detections = model(camera.value)\n",
    "\n",
    "print(detections)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If there are any COCO objects in the camera's field of view, they should now be stored in the ``detections`` variable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display detections in text area\n",
    "\n",
    "We'll use the code below to print out the detected objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "import ipywidgets.widgets as widgets\n",
    "#\n",
    "detections_widget = widgets.Textarea()\n",
    "detections_widget.value = str(detections)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should see the label, confidence, and bounding box of each object detected in each image.  There's only one image (our camera) in this example. \n",
    "\n",
    "\n",
    "To print just the first object detected in the first image, we could call the following\n",
    "\n",
    "> This may throw an error if no objects are detected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_number = 0\n",
    "object_number = 0\n",
    "\n",
    "print(detections[image_number][object_number])"
   ]
  },
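  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The indexing above raises an `IndexError` when nothing is detected.  A minimal sketch of a guarded lookup follows; the helper name `first_detection` and the sample values are hypothetical, and the `confidence` key is assumed from the printed output described above.\n",
    "\n",
    "```python\n",
    "def first_detection(detections, image_number=0):\n",
    "    '''Return the first detection for one image, or None if there is none.'''\n",
    "    if image_number >= len(detections):\n",
    "        return None\n",
    "    per_image = detections[image_number]\n",
    "    return per_image[0] if per_image else None\n",
    "\n",
    "\n",
    "# hypothetical sample in the shape printed above: one list of dicts per image\n",
    "sample = [[{'label': 1, 'confidence': 0.9, 'bbox': [0.1, 0.2, 0.5, 0.8]}]]\n",
    "\n",
    "print(first_detection(sample))   # the single detection dict\n",
    "print(first_detection([[]]))     # None: no objects in the frame\n",
    "```"
   ]
  },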
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Control robot to follow central object\n",
    "\n",
    "Now we want our robot to follow an object of the specified class.  To do this we'll do the following\n",
    "\n",
    "1.  Detect objects matching the specified class\n",
    "2.  Select object closest to center of camera's field of vision, this is the 'target' object\n",
    "3.  Steer robot towards target object, otherwise wander\n",
    "4.  If we're blocked by an obstacle, turn left\n",
    "\n",
    "We'll also create some widgets that we'll use to control the target object label, the robot speed, and\n",
    "a \"turn gain\", that will control how fast the robot turns based off the distance between the target object\n",
    "and the center of the robot's field of view. \n",
    "\n",
    "\n",
    "First, let's load our collision detection model.  The pre-trained model is stored in this directory as a convenience, but if you followed\n",
    "the collision avoidance example you may want to use that model if it's better tuned for your robot's environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "collision_model = torchvision.models.alexnet(pretrained=False)\n",
    "collision_model.classifier[6] = torch.nn.Linear(collision_model.classifier[6].in_features, 2)\n",
    "collision_model.load_state_dict(torch.load('../collision_avoidance/best_model.pth'))\n",
    "device = torch.device('cuda')\n",
    "collision_model = collision_model.to(device)\n",
    "\n",
    "mean = 255.0 * np.array([0.485, 0.456, 0.406])\n",
    "stdev = 255.0 * np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "normalize = torchvision.transforms.Normalize(mean, stdev)\n",
    "\n",
    "def preprocess(camera_value):\n",
    "    global device, normalize\n",
    "    x = camera_value\n",
    "    x = cv2.resize(x, (224, 224))\n",
    "    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)\n",
    "    x = x.transpose((2, 0, 1))\n",
    "    x = torch.from_numpy(x).float()\n",
    "    x = normalize(x)\n",
    "    x = x.to(device)\n",
    "    x = x[None, ...]\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, now let's initialize our robot so we can control the motors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetbot import Robot\n",
    "import math\n",
    "\n",
    "robot = Robot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's display all the control widgets and connect the network execution function to the camera updates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def steering(x, y): \n",
    "    #script from stackexchange of user Pedro Werneck \n",
    "    #https://electronics.stackexchange.com/questions/19669/algorithm-for-mixing-2-axis-analog-input-to-control-a-differential-motor-drive\n",
    "    # convert to polar\n",
    "    r = math.hypot(x, y)\n",
    "    t = math.atan2(y, x)\n",
    "\n",
    "    # rotate by 45 degrees\n",
    "    t += math.pi / -4.0\n",
    "\n",
    "    # back to cartesian\n",
    "    left = r * math.cos(t)\n",
    "    right = r * math.sin(t)\n",
    "\n",
    "    # rescale the new coords\n",
    "    left = left * math.sqrt(2)\n",
    "    right = right * math.sqrt(2)\n",
    "\n",
    "    # clamp to -1/+1 useing the speed widget slider Max 1 but movement at 0.2\n",
    "    scalefactor= speed_widget.value\n",
    "    left = max(scalefactor*-1.0, min(left, scalefactor))\n",
    "    right = max(scalefactor*-1.0, min(right, scalefactor))\n",
    "    \n",
    "    #gamma correction for response sensitivity of joystick or center_x value while turning : TB\n",
    "    gamma=turn_gain_widget.value  #using slider for joystick 1-4, for object recognition 2-40  \n",
    "    if left <0.0 :\n",
    "        left= -1.0* (((abs(left)/scalefactor)**(1.0/gamma))*scalefactor)\n",
    "    else:\n",
    "        left= ((abs(left)/scalefactor)**(1.0/gamma))*scalefactor\n",
    "       \n",
    "    if right <0.0:\n",
    "        right= -1.0*(((abs(right)/scalefactor)**(1.0/gamma))*scalefactor)\n",
    "    else:\n",
    "        right= ((abs(right)/scalefactor)**(1.0/gamma))*scalefactor\n",
    "            \n",
    "    return left, right"
   ]
  },
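  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rotate-and-rescale step in `steering` works out to `left = x + y` and `right = y - x` before clamping.  The sketch below is a minimal standalone version for trying values by hand; the name `mix` and the parameters `speed` and `gamma` are hypothetical stand-ins for the widget values, not part of the JetBot API.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def mix(x, y, speed=0.3, gamma=2.5):\n",
    "    '''Map a steering offset x and forward command y to (left, right) motor values.'''\n",
    "    # polar form, rotated by -45 degrees, then back to cartesian and rescaled\n",
    "    r = math.hypot(x, y)\n",
    "    t = math.atan2(y, x) - math.pi / 4.0\n",
    "    left = r * math.cos(t) * math.sqrt(2)\n",
    "    right = r * math.sin(t) * math.sqrt(2)\n",
    "\n",
    "    def shape(v):\n",
    "        # clamp to [-speed, speed], then apply the gamma curve to the magnitude\n",
    "        v = max(-speed, min(v, speed))\n",
    "        return math.copysign((abs(v) / speed) ** (1.0 / gamma) * speed, v)\n",
    "\n",
    "    return shape(left), shape(right)\n",
    "\n",
    "\n",
    "print(mix(0.0, 0.3))   # straight ahead: both motors at about 0.3\n",
    "print(mix(0.1, 0.3))   # target to the right: left motor faster than right\n",
    "```\n",
    "\n",
    "With a positive `x` (target right of center) the left motor gets the larger value, so the robot turns right."
   ]
  },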
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cell T1\n",
    "from jetbot import bgr8_to_jpeg\n",
    "import traitlets\n",
    "import IPython\n",
    "import time\n",
    "#import random\n",
    "\n",
    "global i, ii, oldcenter,w,countnoobject\n",
    "\n",
    "blocked_widget = widgets.FloatSlider(min=0.0, max=1.0, value=0.0, description='blocked')\n",
    "image_widget = widgets.Image(format='jpeg', width=300, height=300)\n",
    "label_widget = widgets.IntText(value=1, description='tracked label')\n",
    "#chose an item: water bottle label 44: click on the very right of the box to activate label number\n",
    "collisionchoice_widget=widgets.IntText(value= 1, description='ColA.N=0Y=1)')\n",
    "#0 no collison avoidance: will stop, can be used to test object targeting without collison avoidance interference\n",
    "#1 will include search mode with collison avoidance, need to click \n",
    "\n",
    "speed_widget = widgets.FloatSlider(value=0.3, min=0.05, max=1.0, step=0.001, description='speed')\n",
    "#TB higher speed requires smaller turn_gain values: 2.5 for speed 0.22, around 2 for speed 0.4\n",
    "turn_gain_widget = widgets.FloatSlider(value=2.50, min=0.50, step=0.001, max=140.0, description='turn sensitivity')\n",
    "#TB value different for different forward speed, but very small differences\n",
    "motoradjustment_widget = widgets.FloatSlider(value=0.04, min=0.00, max=0.2, step=0.0001, description='motoradjustment')\n",
    "\n",
    "#TB object_stop_threshold-----------------------------------------------------------------------\n",
    "stop_threshold = widgets.FloatSlider(min=-0.2, max=1.2, value=0.8, step=0.01, description='object_stop_threshold')\n",
    "stopduration_slider= widgets.IntSlider(min=1, max=1000, step=1, value=40, description='Manu. time stop') #anti-collision stop time ~ frames*0.1s if 10 Frames per seconds\n",
    "block_threshold= widgets.FloatSlider(min=0, max=1, step=0.1, value=0.85, description='Manu. bl threshold') #anti-collision block probability\n",
    "\n",
    "#widgets for visualization of some values, from data_collection TB\n",
    "x=0.0\n",
    "i=0.0\n",
    "ii=00.0\n",
    "w=0.0\n",
    "count_stops=0\n",
    "countnoobjects=0.0\n",
    "oldcenter=0.0\n",
    "go_on=1\n",
    "x=0\n",
    "y=0\n",
    "button_layout = widgets.Layout(width='140px', height='64px') #TB\n",
    "\n",
    "centerY = widgets.FloatText(layout=button_layout, value=x, description='centerY') #TB\n",
    "centerX = widgets.FloatText(layout=button_layout, value=x, description='centerX') #TB\n",
    "value1 = widgets.FloatText(layout=button_layout, value=x, description='x') #TB\n",
    "value2 = widgets.FloatText(layout=button_layout, value=x, description='stop botton') #TB\n",
    "lostImages = widgets.FloatText(layout=button_layout, value=x, description='lost Object') #TB\n",
    "countframes= widgets.FloatText(layout=button_layout, value=x, description='Nr. No-Object (Search>500)') #TB\n",
    "motorleft = widgets.FloatText(layout=button_layout, value=x, description=\"left motor\") #TB\n",
    "motorright = widgets.FloatText(layout=button_layout, value=x, description=\"right motor\") #TB\n",
    "\n",
    "\n",
    "display(widgets.VBox([\n",
    "    widgets.HBox([image_widget, blocked_widget,stopduration_slider,block_threshold]),\n",
    "    widgets.HBox([widgets.Label(value=\"Track Label Number e.g. 44 for bottle....................................Activate Collision Avoidance 0: No, 1: Yes\")]),\n",
    "    widgets.HBox([label_widget, collisionchoice_widget]),\n",
    "    speed_widget,\n",
    "    turn_gain_widget,\n",
    "    motoradjustment_widget,\n",
    "    stop_threshold\n",
    "]))\n",
    "\n",
    "d2 = IPython.display.display(\"\", display_id=2)\n",
    "\n",
    "#Display widgets from cell T1\n",
    "display(widgets.HBox([centerY, centerX, value1, value2]))\n",
    "display(widgets.HBox([motorleft, motorright, lostImages, countframes]))\n",
    "\n",
    "#TB traitlets could be used or direct transfer to motors, what is faster?\n",
    "#left_link = traitlets.dlink((motorleft, 'value'), (robot.left_motor, 'value'))\n",
    "#right_link = traitlets.dlink((motorright, 'value'), (robot.right_motor, 'value'))\n",
    "\n",
    "width = int(image_widget.width)\n",
    "height = int(image_widget.height)\n",
    "\n",
    "def detection_center(detection):\n",
    "    \"\"\"Computes the center x, y coordinates of the object\"\"\"\n",
    "    bbox = detection['bbox']\n",
    "    center_x = (bbox[0] + bbox[2]) / 2.0 - 0.5\n",
    "    center_y = (bbox[1] + bbox[3]) / 2.0 - 0.5\n",
    "    botton=bbox[3]\n",
    "    centerY.value=center_y\n",
    "    centerX.value=center_x\n",
    "    return (center_x, center_y, botton)\n",
    "    \n",
    "def norm(vec):\n",
    "    \"\"\"Computes the length of the 2D vector\"\"\"\n",
    "    return np.sqrt(vec[0]**2 + vec[1]**2)\n",
    "\n",
    "def closest_detection(detections):\n",
    "    \"\"\"Finds the detection closest to the image center\"\"\"\n",
    "    closest_detection = None\n",
    "    for det in detections:\n",
    "        center = detection_center(det)\n",
    "        if closest_detection is None:\n",
    "            closest_detection = det\n",
    "        elif norm(detection_center(det)) < norm(detection_center(closest_detection)):\n",
    "            closest_detection = det\n",
    "    return closest_detection\n",
    "\n",
    "def collisionavoidance(image):\n",
    "    # execute collision model to determine if blocked\n",
    "    collision_output = collision_model(preprocess(image)).detach().cpu()\n",
    "    prob_blocked = float(F.softmax(collision_output.flatten(), dim=0)[0])\n",
    "    blocked_widget.value = prob_blocked\n",
    "    \n",
    "    # turn left if blocked\n",
    "    if prob_blocked > block_threshold.value:\n",
    "        robot.left(0.3) \n",
    "        motorright.value=0.3\n",
    "        motorleft.value=0\n",
    "    else:\n",
    "        y=speed_widget.value\n",
    "        leftnew, rightnew= steering(0,y) #exchange left right in case for wrong-side steering\n",
    "        motorright.value= round(rightnew, 2) #exchange left right in case for wrong-side steering\n",
    "        motorleft.value= round(leftnew+motoradjustment_widget.value, 2)\n",
    "        #motoradjustment value important to keep bot driving straight, small offset-values like 0.05\n",
    "        robot.right_motor.value=motorright.value\n",
    "        robot.left_motor.value=motorleft.value\n",
    "        \n",
    "    return\n",
    "        \n",
    "def execute(change):\n",
    "    global i,ii,oldcenter,w, countnoobjects,stop_threshold,stopduration_slider, count_stops, go_on,x,y\n",
    "    \n",
    "    t1 = time.time()\n",
    "    stop_time=stopduration_slider.value\n",
    "    image = change['new']\n",
    "            \n",
    "    # compute all detected objects\n",
    "    detections = model(image)\n",
    "    \n",
    "    # draw all detections on image\n",
    "    for det in detections[0]:\n",
    "        bbox = det['bbox']\n",
    "        cv2.rectangle(image, (int(width * bbox[0]), int(height * bbox[1])), (int(width * bbox[2]), int(height * bbox[3])), (255, 0, 0), 2)\n",
    "    \n",
    "    # select detections that match selected class label\n",
    "    matching_detections = [d for d in detections[0] if d['label'] == int(label_widget.value)]\n",
    "    \n",
    "    #get detection closest to center of field of view and draw it\n",
    "    #here start to check for previous postive detections? TB to avoid object skipping runave 4\n",
    "    det = closest_detection(matching_detections)\n",
    "    if det is not None:\n",
    "        bbox = det['bbox']\n",
    "        cv2.rectangle(image, (int(width * bbox[0]), int(height * bbox[1])), (int(width * bbox[2]), int(height * bbox[3])), (0, 255, 0), 5)\n",
    "      \n",
    "    # otherwise go stop or go forward if no target detected\n",
    "    #TB decide Object is lost or just not detected ii: needs to be adjusted\n",
    "    if det is None:\n",
    "        if i>=1.0 and 1.0<=ii<=8.0:\n",
    "            w+=1.0 #if only few images without objects, just stop and wait for object appears\n",
    "            lostImages.value=w\n",
    "            if ii>1:                \n",
    "                motorright.value=0.0\n",
    "                motorleft.value= 0.0\n",
    "                robot.left_motor.value=0.0\n",
    "                robot.right_motor.value=0.0\n",
    "            countnoobjects=0.0\n",
    "        else:    \n",
    "            i=0.0\n",
    "            ii=0.0\n",
    "            w=0.0\n",
    "            countnoobjects+=1.0\n",
    "            lostImages.value=w\n",
    "            countframes.value=countnoobjects\n",
    "            #activation of search mode and using collision_avoidance trained base\n",
    "            if int(collisionchoice_widget.value)==1:\n",
    "                if countnoobjects>100: #number of frames/time to wait before going into search mode Object detection about 16 Fps\n",
    "                    collisionavoidance(image)\n",
    "            else:\n",
    "                robot.forward(float(0)) \n",
    "                \n",
    "        ii+=float(1.0)\n",
    "                    \n",
    "    # otherwsie steer towards target\n",
    "    else:\n",
    "        # move robot forward and steer proportional target's x-distance from center\n",
    "        if w>=1.0 :\n",
    "            i+=1.0 #to count the object event per image\n",
    "            ii=1.0\n",
    "        else:\n",
    "            i+=1.0\n",
    "            \n",
    "        center = detection_center(det)\n",
    "        \n",
    "        #------add stop before reaching object\n",
    "        stop_before_object=center[2]\n",
    "        value2.value=stop_before_object #show box center value, adjusted\n",
    "        if go_on==1:      \n",
    "            if stop_before_object>=stop_threshold.value:\n",
    "                y=0.0 #set speed zero\n",
    "                x=0.0\n",
    "                count_stops +=1\n",
    "                go_on=2\n",
    "            else:\n",
    "                go_on=1\n",
    "                count_stops=0\n",
    "                y=speed_widget.value\n",
    "                x=center[0]/4.0 #TB reduce the values for x coordinates to make steering less sensitive\n",
    "        else:\n",
    "            count_stops += 1\n",
    "            if count_stops<stop_time:\n",
    "                y=0 #speed\n",
    "                x=0\n",
    "            else:\n",
    "                go_on=1\n",
    "                count_stops=0\n",
    "                y=speed_widget.value\n",
    "      \n",
    "    \n",
    "        #------\n",
    "        value1.value=x #show box botton value, \n",
    "        leftnew, rightnew= steering(x,y) #exchange left right in case for wrong-side steering\n",
    "        motorright.value= round(rightnew, 2) #exchange left right in case for wrong-side steering\n",
    "        motorleft.value= round(leftnew+motoradjustment_widget.value, 2)\n",
    "        #motoradjustment value important to keep bot driving straight, small offset-values like 0.05\n",
    "        robot.right_motor.value=motorright.value\n",
    "        robot.left_motor.value=motorleft.value     \n",
    "          \n",
    "    \n",
    "    # update image widget   \n",
    "    image_widget.value = bgr8_to_jpeg(image)\n",
    "    #timer\n",
    "    t2 = time.time()\n",
    "    s = f\"\"\"{int(1/(t2-t1))} FPSS\"\"\"\n",
    "    d2.update(IPython.display.HTML(s) )\n",
    "    \n",
    "execute({'new': camera.value})"
   ]
  },
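  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target selection in the cell above can be tried in isolation, without the camera or widgets.  This is a minimal sketch under the assumption, taken from the drawing code, that `bbox` holds normalized `[x0, y0, x1, y1]` coordinates; the helper names and the sample frame are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def center_offset(bbox):\n",
    "    '''Offset of a normalized [x0, y0, x1, y1] box center from the image center.'''\n",
    "    return ((bbox[0] + bbox[2]) / 2.0 - 0.5, (bbox[1] + bbox[3]) / 2.0 - 0.5)\n",
    "\n",
    "\n",
    "def closest_to_center(detections):\n",
    "    '''Return the detection whose box center is nearest the image center, or None.'''\n",
    "    return min(detections, key=lambda d: math.hypot(*center_offset(d['bbox'])), default=None)\n",
    "\n",
    "\n",
    "# hypothetical detections for one frame\n",
    "frame = [\n",
    "    {'label': 44, 'bbox': [0.0, 0.0, 0.2, 0.2]},   # far from the center\n",
    "    {'label': 44, 'bbox': [0.4, 0.4, 0.7, 0.6]},   # near the center\n",
    "]\n",
    "\n",
    "target = closest_to_center(frame)\n",
    "print(center_offset(target['bbox']))   # small positive x: target slightly right of center\n",
    "print(closest_to_center([]))           # None when nothing matches\n",
    "```"
   ]
  },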
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call the block below to connect the execute function to each camera frame update."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.unobserve_all()\n",
    "camera.observe(execute, names='value')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Awesome!  If the robot is not blocked you should see boxes drawn around the detected objects in blue.  The target object (which the robot follows) will be displayed in green.\n",
    "\n",
    "The robot should steer towards the target when it is detected.  If it is blocked by an object it will simply turn left.\n",
    "\n",
    "You can call the code block below to manually disconnect the processing from the camera and stop the robot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "camera.unobserve_all()\n",
    "time.sleep(1.0)\n",
    "robot.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera_link.unlink()  # TB 12072020 don't stream to browser (will still run camera)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
